"""Decentralized Distributed PPO implementation.

Unlike APPO or PPO, learning is no longer centralized in the trainer
process. Instead, gradients are computed remotely on each rollout worker and
all-reduced to sync them at each mini-batch. This allows each worker's GPU
to be used both for sampling and for training.

DD-PPO should be used if you have envs that require GPUs to function, or have
a very large model that cannot be effectively optimized with the GPUs available
on a single machine (DD-PPO allows scaling to arbitrary numbers of GPUs across
multiple nodes, unlike PPO/APPO, which are limited to GPUs on a single node).

Paper reference: https://arxiv.org/abs/1911.00357
Note that unlike the paper, we currently do not implement straggler mitigation.
"""

import logging
import time

import ray
from ray.util.iter import LocalIterator
from ray.rllib.agents.ppo import ppo
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.optimizers import TorchDistributedDataParallelOptimizer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    STEPS_TRAINED_COUNTER, LEARNER_INFO, LEARN_ON_BATCH_TIMER
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.utils.sgd import do_minibatch_sgd

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
    # During the sampling phase, each rollout worker will collect a batch
    # `rollout_fragment_length * num_envs_per_worker` steps in size.
    "rollout_fragment_length": 100,
    # Vectorize the env (should enable by default since each worker has a GPU).
    "num_envs_per_worker": 5,
    # During the SGD phase, workers iterate over minibatches of this size.
    # The effective minibatch size will be `sgd_minibatch_size * num_workers`.
    "sgd_minibatch_size": 50,
    # Number of SGD epochs per optimization round.
    "num_sgd_iter": 10,
    # Download weights between each training step. This adds a bit of overhead
    # but allows the user to access the weights from the trainer.
    "keep_local_weights_in_sync": True,

    # *** WARNING: configs below are DDPPO overrides over PPO; you
    #     shouldn't need to adjust them. ***
    "use_pytorch": True,  # DDPPO requires PyTorch distributed.
    "num_gpus": 0,  # Learning is no longer done on the driver process, so
                    # giving GPUs to the driver does not make sense!
    "num_gpus_per_worker": 1,  # Each rollout worker gets a GPU.
    # Require evenly sized batches. Otherwise, collective allreduce could
    # fail. `validate_config` rejects any other batch mode, so the override
    # must be spelled with the `batch_mode` key it actually checks.
    "batch_mode": "truncate_episodes",
    "train_batch_size": -1,  # This is auto set based on sample batch size.
})
# __sphinx_doc_end__
# yapf: enable


def validate_config(config):
    """Check a DDPPO trainer config and fill in the derived batch size.

    `train_batch_size` must arrive as the -1 placeholder; it is overwritten
    in place with `rollout_fragment_length * num_envs_per_worker`. The
    DDPPO-specific constraints are then checked one by one before the config
    is handed to `ppo.validate_config`.

    Args:
        config: Trainer config dict. Mutated in place.

    Raises:
        ValueError: If `train_batch_size` was set explicitly, PyTorch is
            not enabled, the driver was given GPUs, or the batch mode is
            not "truncate_episodes".
    """
    # The train batch size is derived, never user-supplied.
    if config["train_batch_size"] != -1:
        raise ValueError(
            "Set rollout_fragment_length instead of train_batch_size "
            "for DDPPO.")
    steps_per_worker = (
        config["rollout_fragment_length"] * config["num_envs_per_worker"])
    config["train_batch_size"] = steps_per_worker

    torch_enabled = config["use_pytorch"]
    if not torch_enabled:
        raise ValueError(
            "Distributed data parallel is only supported for PyTorch")

    driver_gpus = config["num_gpus"]
    if driver_gpus:
        raise ValueError(
            "When using distributed data parallel, you should set "
            "num_gpus=0 since all optimization "
            "is happening on workers. Enable GPUs for workers by setting "
            "num_gpus_per_worker=1.")

    batch_mode = config["batch_mode"]
    if batch_mode != "truncate_episodes":
        raise ValueError(
            "Distributed data parallel requires truncate_episodes "
            "batch mode.")

    # Finish with the checks shared with plain PPO.
    ppo.validate_config(config)


def make_distributed_allreduce_optimizer(workers, config):
    """Build the torch distributed-data-parallel optimizer for DDPPO.

    Args:
        workers: Worker set, passed through to the optimizer unchanged.
        config: Trainer config dict. Reads `rollout_fragment_length`,
            `num_envs_per_worker`, `num_sgd_iter` and `sgd_minibatch_size`.

    Returns:
        A `TorchDistributedDataParallelOptimizer` instance.
    """
    # Per-worker sample batch size: fragment length times vectorized envs.
    per_worker_batch = (
        config["rollout_fragment_length"] * config["num_envs_per_worker"])
    optimizer_kwargs = dict(
        expected_batch_size=per_worker_batch,
        num_sgd_iter=config["num_sgd_iter"],
        sgd_minibatch_size=config["sgd_minibatch_size"],
        standardize_fields=["advantages"],
    )
    return TorchDistributedDataParallelOptimizer(workers, **optimizer_kwargs)


# Experimental distributed execution impl; enable with "use_exec_api": True.
def execution_plan(workers, config):
    """Build the DDPPO training dataflow on top of the rollout workers.

    A torch process group ("gloo" backend) is set up across all remote
    workers, with the first remote worker's address as leader. Each rollout
    batch is then trained on remotely via `do_minibatch_sgd`, and the driver
    records the sampled/trained step counters, learner info and SGD timing.
    If `keep_local_weights_in_sync` is set, the local worker's weights are
    pushed to the remote workers up front and pulled back from the first
    remote worker after every training step.

    Args:
        workers: Worker set; must contain at least one remote worker.
        config: Trainer config dict. Reads `keep_local_weights_in_sync`,
            `rollout_fragment_length`, `num_envs_per_worker`,
            `num_sgd_iter` and `sgd_minibatch_size`.

    Returns:
        The result of `StandardMetricsReporting` applied to the training op.

    Raises:
        ValueError: If `workers` has no remote workers.
    """
    # Fail fast: everything below needs at least one remote worker, so check
    # before building the rollout iterator or broadcasting any weights.
    if not workers.remote_workers():
        raise ValueError("This optimizer requires >0 remote workers.")

    rollouts = ParallelRollouts(workers, mode="raw")

    # Sync up the weights. In principle we don't need this, but it doesn't
    # add too much overhead and handles the case where the user manually
    # updates the local weights.
    if config["keep_local_weights_in_sync"]:
        weights = ray.put(workers.local_worker().get_weights())
        for e in workers.remote_workers():
            e.set_weights.remote(weights)

    # Setup the distributed processes. The first remote worker provides the
    # rendezvous address for the process group.
    ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
    port = ray.get(workers.remote_workers()[0].find_free_port.remote())
    address = "tcp://{ip}:{port}".format(ip=ip, port=port)
    logger.info("Creating torch process group with leader {}".format(address))

    # Get setup tasks in order to throw errors on failure.
    ray.get([
        worker.setup_torch_data_parallel.remote(
            address, i, len(workers.remote_workers()), backend="gloo")
        for i, worker in enumerate(workers.remote_workers())
    ])
    logger.info("Torch process group init completed")

    # This function is applied remotely on each rollout worker.
    def train_torch_distributed_allreduce(batch):
        """Run minibatch SGD on `batch`; return `(info, batch.count)`."""
        expected_batch_size = (
            config["rollout_fragment_length"] * config["num_envs_per_worker"])
        this_worker = get_global_worker()
        assert batch.count == expected_batch_size, \
            ("Batch size possibly out of sync between workers, expected:",
             expected_batch_size, "got:", batch.count)
        logger.info("Executing distributed minibatch SGD "
                    "with epoch size {}, minibatch size {}".format(
                        batch.count, config["sgd_minibatch_size"]))
        info = do_minibatch_sgd(batch, this_worker.policy_map, this_worker,
                                config["num_sgd_iter"],
                                config["sgd_minibatch_size"], ["advantages"])
        return info, batch.count

    # Have to manually record stats since we are using "raw" rollouts mode.
    class RecordStats:
        """Callable that folds per-worker `(info, count)` results into the
        iterator's shared metrics."""

        def _on_fetch_start(self):
            # NOTE(review): presumably invoked by the iterator machinery
            # right before each fetch -- confirm against ray.util.iter.
            self.fetch_start_time = time.perf_counter()

        def __call__(self, items):
            # Look the metrics context up once per call: it does not depend
            # on the item, and doing it here keeps `metrics` bound for the
            # timer update below even when `items` is empty.
            metrics = LocalIterator.get_metrics()
            for info, count in items:
                metrics.counters[STEPS_SAMPLED_COUNTER] += count
                metrics.counters[STEPS_TRAINED_COUNTER] += count
                metrics.info[LEARNER_INFO] = info
            # Since SGD happens remotely, the time delay between fetch and
            # completion is approximately the SGD step time.
            metrics.timers[LEARN_ON_BATCH_TIMER].push(time.perf_counter() -
                                                      self.fetch_start_time)

    train_op = (
        rollouts.for_each(train_torch_distributed_allreduce)  # allreduce
        .batch_across_shards()  # List[(grad_info, count)]
        .for_each(RecordStats()))

    # Sync down the weights. As with the sync up, this is not really
    # needed unless the user is reading the local weights.
    if config["keep_local_weights_in_sync"]:

        def download_weights(item):
            """Copy the first remote worker's weights to the local worker."""
            workers.local_worker().set_weights(
                ray.get(workers.remote_workers()[0].get_weights.remote()))
            return item

        train_op = train_op.for_each(download_weights)

    # In debug mode, check the allreduce successfully synced the weights.
    if logger.isEnabledFor(logging.DEBUG):

        def check_sync(item):
            """Assert that all remote workers hold identical weight sums."""
            weights = ray.get(
                [w.get_weights.remote() for w in workers.remote_workers()])
            sums = []
            for w in weights:
                acc = 0
                for p in w.values():
                    # Only the values are summed; the keys are not needed.
                    for v in p.values():
                        acc += v.sum()
                sums.append(float(acc))
            logger.debug("The worker weight sums are {}".format(sums))
            assert len(set(sums)) == 1, sums
            # Pass the item through unchanged, like download_weights does.
            return item

        train_op = train_op.for_each(check_sync)

    return StandardMetricsReporting(train_op, workers, config)


# DDPPO is the PPO trainer with its defaults, config validation, policy
# optimizer factory and execution plan replaced by the definitions above.
DDPPOTrainer = ppo.PPOTrainer.with_updates(
    name="DDPPO",
    validate_config=validate_config,
    default_config=DEFAULT_CONFIG,
    execution_plan=execution_plan,
    make_policy_optimizer=make_distributed_allreduce_optimizer,
)
